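"""
The original image is 224x224, i.e. an input tensor of shape (1, 3, 224, 224).
Split the image into 16x16 patches (1, 3, 16, 16, 14x14) => (1, 768, 196)  # (batch_size, embed_dim, num_patches), with 768 = 3*16*16 and 196 = 14*14
"""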
import torch
from torch import nn

# Patch embedding as a strided convolution: with kernel size and stride both
# 16, every 16x16 window of the 3-channel image is visited exactly once and
# projected to a 768-dimensional vector.
conv_layer = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16)

# A single random RGB image of size 224x224 (batch of one).
inputs = torch.randn((1, 3, 224, 224))

# 224 / 16 = 14 windows per side, so the feature map is (1, 768, 14, 14).
outputs = conv_layer(inputs)
print(outputs.shape)

# Collapse the 14x14 spatial grid into one axis of 196 patches: (1, 768, 196).
outputs = outputs.flatten(start_dim=2)
print(outputs.shape)
